#!/usr/bin/env python3

import os
import json
from statistics import mean
from collections import Counter
import pandas as pd
import traceback

# ==============================
# CONFIGURATION
# ==============================

# Directory that holds one sub-folder per OS type ("windows", "linux", "mac");
# the main loop walks each of those for plugin JSON files.
ROOT_DIR = "json_output"
# Destination path of the generated feature table.
OUTPUT_CSV = "memory_feature_vectors.csv"
# Placeholder stored by finalize_vector() for schema columns that were not
# computed when the dump's os_type is not "windows".
NA_VALUE = "NotApplicable"   # or "#NA"
# When True, log() prints its messages; when False they are suppressed.
VERBOSE = True

# ==============================
# GLOBAL FEATURE SCHEMA
# ==============================

# Ordered list of output columns. finalize_vector() emits exactly these keys
# and the final DataFrame is built with this column order.
# NOTE(review): several columns listed here have no extractor in this file
# (svcscan.* other than svcscan.nservices, callbacks.nanonymous,
# callbacks.ngeneric); they are always filled with the default value
# chosen by finalize_vector().
FEATURE_COLUMNS = [
    "Filename","os_type",
    "pslist.nproc","pslist.nppid","pslist.avg_threads",
    "pslist.nprocs64bit","pslist.avg_handlers",
    "dlllist.ndlls","dlllist.avg_dlls_per_proc",
    "handles.nhandles","handles.avg_handles_per_proc",
    "handles.nport","handles.nfile","handles.nevent",
    "handles.ndesktop","handles.nkey","handles.nthread",
    "handles.ndirectory","handles.nsemaphore","handles.ntimer",
    "handles.nsection","handles.nmutant",
    "ldrmodules.not_in_load","ldrmodules.not_in_init","ldrmodules.not_in_mem",
    "ldrmodules.not_in_load_avg","ldrmodules.not_in_init_avg","ldrmodules.not_in_mem_avg",
    "malfind.ninjections","malfind.commitCharge","malfind.uniqueInjections",
    "psxview.not_in_pslist","psxview.not_in_eprocess_pool","psxview.not_in_ethread_pool",
    "psxview.not_in_pspcid_list","psxview.not_in_csrss_handles",
    "psxview.not_in_session","psxview.not_in_deskthrd",
    "psxview.not_in_pslist_false_avg","psxview.not_in_eprocess_pool_false_avg",
    "psxview.not_in_ethread_pool_false_avg","psxview.not_in_pspcid_list_false_avg",
    "psxview.not_in_csrss_handles_false_avg","psxview.not_in_session_false_avg",
    "psxview.not_in_deskthrd_false_avg",
    "modules.nmodules",
    "svcscan.nservices","svcscan.kernel_drivers","svcscan.fs_drivers",
    "svcscan.process_services","svcscan.shared_process_services",
    "svcscan.interactive_process_services","svcscan.nactive",
    "callbacks.ncallbacks","callbacks.nanonymous","callbacks.ngeneric"
]

# ==============================
# UTILITIES
# ==============================

def log(msg):
    """Print *msg* to stdout, but only while the VERBOSE switch is on."""
    if not VERBOSE:
        return
    print(msg)

def load_rows(path):
    """Load the list of row records stored in a plugin JSON file.

    Two layouts are accepted: a top-level JSON list (returned as-is), or a
    JSON object whose "rows" entry holds the list.

    Args:
        path: Path to the JSON file. May be None or empty, e.g. when the
            plugin produced no output for a dump.

    Returns:
        The list of rows, or an empty list when the path is missing, the
        file cannot be read or parsed, or it contains no row list.
    """
    if not path or not os.path.isfile(path):
        return []

    try:
        # Decode as UTF-8 explicitly rather than relying on the platform's
        # locale-dependent default encoding.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # RecursionError covers pathologically nested documents.
        log(f"[!] Failed to parse JSON: {path} ({e})")
        return []

    if isinstance(data, dict):
        # A present-but-null (or otherwise non-list) "rows" entry must not
        # leak out: callers pass the result straight to len() and loops.
        data = data.get("rows")

    return data if isinstance(data, list) else []

def avg(values):
    """Return the arithmetic mean of the int/float items in *values*.

    Items of any other type are skipped. When no numeric item remains the
    result is 0.
    """
    numeric = []
    for item in values:
        if isinstance(item, (int, float)):
            numeric.append(item)
    if not numeric:
        return 0
    return mean(numeric)

# ==============================
# FEATURE EXTRACTORS
# ==============================

def pslist_features(rows):
    """Summarise process-list rows into the pslist.* feature values.

    Args:
        rows: List of dict-like process records.

    Returns:
        Dict with the row count, the number of distinct non-None "PPID"
        values, the averages of "Threads" and "Handles" (a missing key
        counts as 0), and the number of rows whose "Wow64" is exactly False.
    """
    parent_ids = {row.get("PPID") for row in rows}
    parent_ids.discard(None)

    thread_counts = [row.get("Threads", 0) for row in rows]
    handle_counts = [row.get("Handles", 0) for row in rows]
    # Identity test on purpose: only a literal False counts, not 0 or None.
    not_wow64 = len([row for row in rows if row.get("Wow64") is False])

    features = {}
    features["pslist.nproc"] = len(rows)
    features["pslist.nppid"] = len(parent_ids)
    features["pslist.avg_threads"] = avg(thread_counts)
    features["pslist.nprocs64bit"] = not_wow64
    features["pslist.avg_handlers"] = avg(handle_counts)
    return features

def dlllist_features(rows, nproc):
    """Build the dlllist.* features.

    Args:
        rows: List of DLL records; only its length is used.
        nproc: Process count used as the divisor for the per-process average.

    Returns:
        Dict with the total row count and rows-per-process (0 when *nproc*
        is zero or otherwise falsy).
    """
    total = len(rows)
    if nproc:
        per_proc = total / nproc
    else:
        per_proc = 0
    return {
        "dlllist.ndlls": total,
        "dlllist.avg_dlls_per_proc": per_proc,
    }

def handles_features(rows, nproc):
    """Build the handles.* features from handle records.

    Args:
        rows: List of dict-like handle records; each row's "Type" value is
            tallied.
        nproc: Process count used as the divisor for the per-process average.

    Returns:
        Dict with the total handle count, handles-per-process (0 when
        *nproc* is falsy) and one count per tracked handle type.
    """
    type_counts = Counter(row.get("Type") for row in rows)
    total = len(rows)

    features = {
        "handles.nhandles": total,
        "handles.avg_handles_per_proc": total / nproc if nproc else 0,
    }

    # Output column -> value of the "Type" field it counts.
    tracked = (
        ("handles.nfile", "File"),
        ("handles.nkey", "Key"),
        ("handles.nthread", "Thread"),
        ("handles.nsection", "Section"),
        ("handles.nmutant", "Mutant"),
        ("handles.ndirectory", "Directory"),
        ("handles.nevent", "Event"),
        ("handles.ntimer", "Timer"),
        ("handles.nsemaphore", "Semaphore"),
        ("handles.ndesktop", "Desktop"),
        ("handles.nport", "Port"),
    )
    for column, handle_type in tracked:
        # Counter indexing yields 0 for types that never occurred.
        features[column] = type_counts[handle_type]
    return features

def ldrmodules_features(rows, nproc):
    """Build the ldrmodules.* features.

    Counts the rows whose "InLoad", "InInit" and "InMem" fields are exactly
    False, and divides each count by *nproc*.

    Args:
        rows: List of dict-like module records.
        nproc: Process count used as the divisor (averages are 0 when falsy).

    Returns:
        Dict with three counts and three per-process averages.
    """
    missing = {"InLoad": 0, "InInit": 0, "InMem": 0}
    for row in rows:
        for flag in missing:
            # Only a literal False counts; absent or None fields do not.
            if row.get(flag) is False:
                missing[flag] += 1

    def per_proc(count):
        return count / nproc if nproc else 0

    return {
        "ldrmodules.not_in_load": missing["InLoad"],
        "ldrmodules.not_in_init": missing["InInit"],
        "ldrmodules.not_in_mem": missing["InMem"],
        "ldrmodules.not_in_load_avg": per_proc(missing["InLoad"]),
        "ldrmodules.not_in_init_avg": per_proc(missing["InInit"]),
        "ldrmodules.not_in_mem_avg": per_proc(missing["InMem"]),
    }

def malfind_features(rows):
    """Build the malfind.* features.

    Args:
        rows: List of dict-like malfind records.

    Returns:
        Dict with the number of rows, the sum of their numeric
        "CommitCharge" values, and the number of distinct non-None "PID"
        values.
    """
    # A null or non-numeric CommitCharge would make sum() raise TypeError,
    # so only int/float values contribute to the total.
    commit_total = sum(
        value
        for value in (row.get("CommitCharge") for row in rows)
        if isinstance(value, (int, float))
    )

    # Rows without a PID must not be counted as one extra process; this
    # matches how pslist_features() ignores a missing PPID.
    pids = {row.get("PID") for row in rows}
    pids.discard(None)

    return {
        "malfind.ninjections": len(rows),
        "malfind.commitCharge": commit_total,
        "malfind.uniqueInjections": len(pids),
    }

def psxview_features(rows, nproc):
    """Build the psxview.* features.

    For each tracked field, counts the rows where that field is exactly
    False and also reports the count divided by *nproc*.

    Args:
        rows: List of dict-like psxview records.
        nproc: Process count used as the divisor (averages are 0 when falsy).

    Returns:
        Dict with a "psxview.not_in_<field>" count and a
        "psxview.not_in_<field>_false_avg" ratio per tracked field.
    """
    tracked_fields = (
        "pslist",
        "eprocess_pool",
        "ethread_pool",
        "pspcid_list",
        "csrss_handles",
        "session",
        "deskthrd",
    )

    features = {}
    for field in tracked_fields:
        missing = 0
        for row in rows:
            # Only a literal False counts; absent fields are ignored.
            if row.get(field) is False:
                missing += 1
        if nproc:
            ratio = missing / nproc
        else:
            ratio = 0
        features[f"psxview.not_in_{field}"] = missing
        features[f"psxview.not_in_{field}_false_avg"] = ratio
    return features

# ==============================
# CORE PROCESSING
# ==============================

def finalize_vector(raw, os_type):
    """Project *raw* onto the FEATURE_COLUMNS schema.

    Args:
        raw: Dict of features computed so far.
        os_type: OS label of the dump; selects the fill value.

    Returns:
        A new dict with exactly the FEATURE_COLUMNS keys, in schema order.
        Columns absent from *raw* are filled with 0 for "windows" dumps and
        with NA_VALUE for every other os_type.
    """
    if os_type == "windows":
        return {col: raw.get(col, 0) for col in FEATURE_COLUMNS}
    return {col: raw.get(col, NA_VALUE) for col in FEATURE_COLUMNS}

def process_dump(os_type, dump_id, plugins):
    """Compute the complete feature vector for one memory dump.

    Args:
        os_type: OS label of the dump; the Windows-only plugins are read
            only when this is "windows".
        dump_id: Identifier stored in the "Filename" column.
        plugins: Mapping of plugin name -> path of that plugin's JSON file.

    Returns:
        The finalized feature dict, or None if any exception occurred
        (the traceback is printed).
    """
    try:
        log(f"    ↳ Processing dump: {dump_id}")

        def plugin_rows(name):
            # A plugin missing from the mapping yields None -> empty rows.
            return load_rows(plugins.get(name))

        features = {"Filename": dump_id, "os_type": os_type}
        features.update(pslist_features(plugin_rows("pslist")))
        nproc = features.get("pslist.nproc", 0)

        if os_type != "windows":
            return finalize_vector(features, os_type)

        # Extractors that normalise by the process count.
        per_process = (
            ("dlllist", dlllist_features),
            ("handles", handles_features),
            ("ldrmodules", ldrmodules_features),
            ("psxview", psxview_features),
        )
        for name, extractor in per_process:
            features.update(extractor(plugin_rows(name), nproc))

        features.update(malfind_features(plugin_rows("malfind")))

        # Plugins that contribute only a plain row count.
        row_counts = (
            ("modules", "modules.nmodules"),
            ("svcscan", "svcscan.nservices"),
            ("callbacks", "callbacks.ncallbacks"),
        )
        for name, column in row_counts:
            features[column] = len(plugin_rows(name))

        return finalize_vector(features, os_type)

    except Exception:
        log(f"[!] ERROR while processing {dump_id}")
        traceback.print_exc()
        return None

# ==============================
# MAIN
# ==============================

# Accumulates one finalized feature dict per successfully processed dump.
rows = []
# Counts every dump discovered, including those that failed to process.
total_dumps = 0

if not os.path.isdir(ROOT_DIR):
    print(f"[✗] Root directory not found: {ROOT_DIR}")
    # `exit` is a helper injected by the `site` module for interactive use
    # and is not guaranteed to exist; raise SystemExit directly instead.
    raise SystemExit(1)

for os_type in ["windows", "linux", "mac"]:
    os_dir = os.path.join(ROOT_DIR, os_type)
    if not os.path.isdir(os_dir):
        log(f"[!] Skipping missing OS folder: {os_dir}")
        continue

    log(f"[+] Scanning OS folder: {os_type}")

    # dump_id -> {plugin name -> JSON path}. The plugin name is the name of
    # the directory containing the file; the dump id is the part of the
    # file name before its first dot.
    dumps = {}
    for root, _, files in os.walk(os_dir):
        plugin = os.path.basename(root)
        for f in files:
            if f.endswith(".json"):
                dump_id = f.split(".", 1)[0]
                dumps.setdefault(dump_id, {})[plugin] = os.path.join(root, f)

    log(f"    Found {len(dumps)} dumps under {os_type}")

    for dump_id, plugins in dumps.items():
        total_dumps += 1
        row = process_dump(os_type, dump_id, plugins)
        # process_dump() signals failure with None; test for that explicitly
        # instead of relying on dict truthiness.
        if row is not None:
            rows.append(row)

# ==============================
# WRITE CSV
# ==============================

df = pd.DataFrame(rows, columns=FEATURE_COLUMNS)
df = df.sort_values(by=["os_type","Filename"])
df.to_csv(OUTPUT_CSV, index=False)

print(f"[✓] CSV written: {OUTPUT_CSV}")
print(f"[✓] Dumps processed: {len(rows)} / {total_dumps}")

